
Conversation

@amitsrivastava78 (Collaborator)

Supports the following features:

  • Asynchronous Checkpointing
  • Composite Checkpointing
  • Preservation Policies
  • Save Decision Policies
  • Transformations - Custom Handlers

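A minimal usage sketch (data, paths, and argument values are illustrative; the import path is assumed, and the full list of options appears in the __init__ signature quoted in the review below):

import numpy as np
import keras
from keras.callbacks import OrbaxCheckpoint  # import path assumed

x = np.random.rand(32, 4)
y = np.random.rand(32, 1)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

callback = OrbaxCheckpoint(
    directory="/tmp/orbax_ckpt",
    save_freq="epoch",
    max_to_keep=5,            # preservation policy
    save_on_background=True,  # asynchronous checkpointing
)
model.fit(x, y, epochs=3, callbacks=[callback])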
@gemini-code-assist (Contributor)

Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Orbax checkpointing into Keras 3.0, providing a robust and flexible mechanism for saving and restoring training progress. The new OrbaxCheckpoint callback offers features like asynchronous saving, customizable save policies, and the ability to save complex states including model weights, optimizer variables, metrics, and data iterator positions. This enhancement significantly improves the reliability and efficiency of training large models, especially in distributed environments, by leveraging Orbax's advanced capabilities.

Highlights

  • New OrbaxCheckpoint Callback: Introduces a new OrbaxCheckpoint callback for Keras 3.0, enabling advanced data-centric saving and restoration of model states.
  • Asynchronous Checkpointing: Supports asynchronous saving of model weights and optimizer states, allowing training to continue without I/O blocking.
  • Comprehensive Checkpointing Features: Includes support for composite checkpointing, preservation policies (e.g., max_to_keep, keep_period), save decision policies (e.g., save_interval), and custom transformations during saving.
  • Distributed Training Support: Adds a get_process_index utility function to the Keras backend, facilitating distributed training setups by identifying the primary process for checkpoint operations across JAX, TensorFlow, and PyTorch (see the sketch after this list).
  • Extensible with Custom Handlers: Exposes advanced Orbax functionalities like CheckpointManager, TypeHandler, and register_type_handler to allow users to define custom serialization logic for complex objects.
  • Iterator State Saving and Restoration: Enables saving and restoring the state of data iterators, crucial for seamless training resumption from a specific point, with backend-specific examples for TensorFlow, JAX, and PyTorch.
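A hedged sketch of gating checkpoint I/O on the primary process with the get_process_index utility mentioned above (the import location is assumed, and a later commit in this thread moves the helper into the distribution module, so treat it as illustrative):

from keras.src import backend  # import location assumed

def should_write_checkpoint():
    # Only process 0 drives checkpoint I/O in multi-host runs;
    # single-host runs always report index 0.
    return backend.get_process_index() == 0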
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.

@gemini-code-assist (Contributor) left a comment


Code Review

This pull request introduces OrbaxCheckpoint, a new Keras callback for advanced checkpointing using the Orbax library. This is a significant feature addition that enables asynchronous saving, composite checkpoints, and other powerful capabilities. The implementation is extensive and is supported by a comprehensive suite of tests.

My review has identified several important issues that need attention. There are critical correctness and performance bugs in the main implementation: the batch-based saving logic is flawed, and the asynchronous saving feature is effectively disabled by blocking calls. Additionally, some features are incomplete, and there are minor areas for improvement in the tests to enhance maintainability. I have provided specific suggestions to address these points. After these fixes, this will be a very valuable addition to Keras.

Comment on lines 119 to 141
def __init__(
    self,
    directory,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    mode="auto",
    save_freq="epoch",
    max_to_keep=5,
    keep_period=None,
    initial_value_threshold=None,
    save_optimizer_state=True,
    save_on_background=True,
    save_metadata=None,
    save_data_iterator=None,
    save_metrics_state=False,
    async_timeout_secs=600,
    enable_background_delete=False,
    post_finalization_callback=None,
    save_transforms=None,
    save_decision_policy=None,
    save_interval=None,
):

Severity: medium

The __init__ method has 20 arguments, which is quite high. The Keras API design guidelines suggest reconsidering signatures with more than 6-7 arguments.¹ While I understand the need to expose Orbax's functionality, it might be worth exploring whether some of these could be grouped into a configuration object to improve readability and usability, similar to how ocp.CheckpointManagerOptions is used internally.
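A sketch of the grouping the reviewer suggests, with hypothetical names (the fields mirror the save-related parameters quoted above):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class OrbaxSaveOptions:
    # Hypothetical configuration object bundling the save-related
    # arguments of OrbaxCheckpoint.__init__.
    save_on_background: bool = True
    save_metadata: Optional[dict] = None
    save_data_iterator: Optional[Any] = None
    save_metrics_state: bool = False
    save_transforms: Optional[dict] = None
    save_decision_policy: Optional[Any] = None
    save_interval: Optional[int] = None

OrbaxCheckpoint(directory, options=OrbaxSaveOptions(...)) would then carry one object instead of seven scalars, much like ocp.CheckpointManagerOptions does internally.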

Style Guide References

Footnotes

  1. The style guide recommends that functions with more than 6-7 arguments should be re-evaluated for simplification, possibly by breaking them into smaller objects or modular pieces.

@codecov-commenter commented Oct 22, 2025

Codecov Report

❌ Patch coverage is 78.17259% with 43 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.65%. Comparing base (47fcb39) to head (33f4e66).
⚠️ Report is 38 commits behind head on master.

Files with missing lines | Patch % | Lines
keras/src/callbacks/orbax_checkpoint.py | 77.77% | 23 Missing and 19 partials ⚠️
keras/api/_tf_keras/keras/callbacks/__init__.py | 0.00% | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21762      +/-   ##
==========================================
- Coverage   82.69%   82.65%   -0.04%     
==========================================
  Files         573      578       +5     
  Lines       58888    59670     +782     
  Branches     9218     9374     +156     
==========================================
+ Hits        48696    49319     +623     
- Misses       7845     7929      +84     
- Partials     2347     2422      +75     
Flag | Coverage Δ
keras | 82.47% <77.66%> (-0.03%) ⬇️
keras-jax | 63.35% <76.14%> (+0.11%) ⬆️
keras-numpy | 57.41% <17.25%> (-0.31%) ⬇️
keras-openvino | 34.28% <17.25%> (-0.12%) ⬇️
keras-tensorflow | 64.15% <72.58%> (+0.14%) ⬆️
keras-torch | 63.64% <72.58%> (+0.07%) ⬆️

Flags with carried forward coverage won't be shown.


@hertschuh (Collaborator) left a comment


Thanks for the PR. This checkpointing system has a ton of features!

Quick first pass.

@hertschuh (Collaborator) left a comment


A couple more comments I forgot.

- Remove conditional export decorator to ensure OrbaxCheckpoint is always available
- Remove unnecessary exception handling in state tree operations
- Update process index check comment for clarity
- Format code to comply with 80-character line limit
- Add distribution_lib modules for backend-specific distributed training support
- Remove unused 'result' variable in _reconstruct_state_tree_with_values
- Fix long comment line in test file
- Apply code formatting changes
…st handling

- Implement OrbaxCheckpoint callback for async checkpointing with state tree handling
- Add conditional exports for optional orbax-checkpoint dependency
- Use pytest.importorskip for clean optional dependency testing
- Ensure graceful handling when orbax-checkpoint is not installed
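The optional-dependency pattern these commits describe, sketched under assumed module paths (pytest.importorskip is a real pytest helper):

# Library side: degrade gracefully when orbax-checkpoint is not installed.
try:
    import orbax.checkpoint as ocp
except ImportError:
    ocp = None

# Test side: skip the module's tests cleanly instead of failing
# on the missing import.
import pytest
ocp = pytest.importorskip("orbax.checkpoint")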
@hertschuh (Collaborator) left a comment


The JAX implementation of def process_id() is missing.

General questions:

  • Does this as-is support all backends?
  • Does this support JAX sharding? I don't see anything related to sharding (which may be normal). What about re-sharding?
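For reference, a minimal JAX implementation of the missing helper might look like this (jax.process_index() is the standard JAX multi-host API; the function name follows the comment above):

import jax

def process_id():
    # Index of the current process in a multi-host JAX run;
    # always 0 in single-process setups.
    return jax.process_index()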

- Preserve nested state tree structures instead of flattening for better layer name preservation
- Add backward compatibility for old flattened format checkpoints
- Simplify test class by using self.get_temp_dir() instead of setUp/tearDown
- Remove silent pytest.importorskip, add explicit skip conditions for backend-specific tests
- Move process_id function from backend to distribution module
- Update imports to use centralized LazyModule for orbax.checkpoint
- Test across all backends (JAX, TensorFlow, PyTorch) - all passing
checkpoints if there might be pending save operations.
"""
# Wait for any async operations to complete
while self.checkpointer.is_saving_in_progress():


Probably better to use checkpointer.wait() here, unless you want to log things periodically.
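If periodic logging is the goal, the polling approach from the quoted diff could be wrapped like this (the method name is taken from the diff; the interval and message are illustrative):

import time

def wait_for_pending_saves(checkpointer, poll_interval=1.0):
    # Block until any in-flight async save completes, logging as we wait.
    while checkpointer.is_saving_in_progress():
        print("OrbaxCheckpoint: async save still in progress...")
        time.sleep(poll_interval)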

@amitsrivastava78 (Collaborator, Author) replied


I am using Orbax's experimental v1 API, but the checkpointer.wait() method doesn't exist in our installed version. The GitHub link you provided might be from a development branch; I can see it here https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py#L521, but not in the installed released version.

@amitsrivastava78 force-pushed the orbax-checkpoint-test-improvements branch from 621f566 to eb7855d on November 10, 2025 09:45
…s expected failures

Neural networks are inherently non-deterministic, so pipeline consistency
checks should be skipped rather than fail. Added check_pipeline_consistency
to EXPECTED_FAILED_CHECKS for all sklearn wrapper types.
- Avoid unnecessary numpy conversion in _get_state_tree() for JAX backend
- Preserve JAX arrays during saving instead of converting to numpy
- Maintain cross-backend compatibility with proper loading conversions
- Update async waiting to use CheckpointManager.wait_until_finished()
- Implement AlwaysSavePolicy for reliable save decisions
- Add expected failures for sklearn tests due to neural network non-determinism
@amitsrivastava78 force-pushed the orbax-checkpoint-test-improvements branch from c14c30e to b7a0dff on November 11, 2025 05:54
- Preserve JAX arrays during saving when jax.monitoring.record_scalar is available
- Fall back to numpy conversion for older JAX versions that don't have record_scalar
- Maintain cross-backend compatibility while avoiding unnecessary conversions
- Update async waiting to use CheckpointManager.wait_until_finished()
- Implement AlwaysSavePolicy for reliable save decisions
- Add expected failures for sklearn tests due to neural network non-determinism
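A sketch of the version probe these commit notes describe (the commits use the presence of jax.monitoring.record_scalar purely as a proxy for a new-enough JAX; the helper name is hypothetical):

import jax
import numpy as np

def _prepare_for_save(value):
    # Newer JAX: keep the JAX array to avoid an unnecessary copy.
    # Older JAX (no jax.monitoring.record_scalar): fall back to numpy.
    monitoring = getattr(jax, "monitoring", None)
    if monitoring is not None and hasattr(monitoring, "record_scalar"):
        return value
    return np.asarray(value)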